
Add support for linear-time mmd estimator. #475

Open · wants to merge 18 commits into master from linear_time_mmd
Conversation

@Srceh commented Apr 1, 2022

This PR implements the linear-time MMD estimator (Lemma 14 in Gretton et al., 2012), as requested in #288.
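For reference, a minimal sketch of that estimator, assuming a fixed-bandwidth RBF kernel rather than the PR's configurable kernel and sigma inference:

```python
import torch

def rbf(a: torch.Tensor, b: torch.Tensor, sigma: float = 1.0) -> torch.Tensor:
    # Element-wise RBF kernel on paired rows (no n x n matrix is formed).
    return torch.exp(-((a - b) ** 2).sum(-1) / (2 * sigma ** 2))

def linear_mmd2(x: torch.Tensor, y: torch.Tensor, sigma: float = 1.0):
    # Lemma 14: average h((x_{2i-1}, y_{2i-1}), (x_{2i}, y_{2i})) over
    # floor(n/2) disjoint pairs, giving an unbiased O(n) estimate of MMD^2.
    n = min(x.shape[0], y.shape[0]) // 2 * 2
    x, y = x[:n], y[:n]
    h = (rbf(x[0::2], x[1::2], sigma) + rbf(y[0::2], y[1::2], sigma)
         - rbf(x[0::2], y[1::2], sigma) - rbf(y[0::2], x[1::2], sigma))
    return h.mean(), h.var(unbiased=True)  # estimate and its sample variance
```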

@Srceh Srceh requested a review from arnaudvl April 1, 2022 13:02
@Srceh Srceh marked this pull request as ready for review April 1, 2022 13:05
@Srceh Srceh removed the request for review from arnaudvl April 1, 2022 13:37
@Srceh Srceh marked this pull request as draft April 1, 2022 13:37
@ojcobb (Contributor) commented Apr 1, 2022

Do we think users are ever going to have equal reference and test batch sizes in practice? I'd guess almost always the reference set is going to be much larger. I wonder if we'd be better off using the B-stat estimator by default for the linear case rather than Gretton's estimator for equal sample sizes. This additionally has the advantage of a tunable parameter that allows for interpolation between a linear and quadratic time estimator.

@arnaudvl @Srceh

Edit: It's actually not quite this simple. However I think we should put some thought into how best to address the n!=m case.
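For context, a rough sketch of the B-statistic idea (Zaremba et al., 2013), assuming a hypothetical kernel callable that returns a pairwise matrix; block_size = 2 recovers the linear-time estimator and block_size = n the quadratic one:

```python
import torch

def block_mmd2(x: torch.Tensor, y: torch.Tensor, kernel,
               block_size: int = 64) -> torch.Tensor:
    # B-statistic: average the unbiased quadratic-time MMD^2 over disjoint
    # blocks. Assumes n >= block_size and a pairwise kernel(a, b) -> matrix.
    n = min(x.shape[0], y.shape[0])
    estimates = []
    for i in range(0, n - block_size + 1, block_size):
        xb, yb = x[i:i + block_size], y[i:i + block_size]
        k_xx, k_yy, k_xy = kernel(xb, xb), kernel(yb, yb), kernel(xb, yb)
        b = block_size
        est = ((k_xx.sum() - k_xx.trace()) / (b * (b - 1))
               + (k_yy.sum() - k_yy.trace()) / (b * (b - 1))
               - 2 * k_xy.mean())
        estimates.append(est)
    return torch.stack(estimates).mean()
```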

@Srceh (Author) commented Apr 1, 2022

> Do we think users are ever going to have equal reference and test batch sizes in practice? I'd guess almost always the reference set is going to be much larger. I wonder if we'd be better off using the B-stat estimator by default for the linear case rather than Gretton's estimator for equal sample sizes. This additionally has the advantage of a tunable parameter that allows for interpolation between a linear and quadratic time estimator.
>
> @arnaudvl @Srceh
>
> Edit: It's actually not quite this simple. However I think we should put some thought into how best to address the n!=m case.

Agreed. Maybe we can leave the current PR as it is for the linear-time estimator, and do a separate one for the additional B-stat implementation.

@arnaudvl (Contributor) commented Apr 1, 2022

@Srceh @ojcobb Wondering if it wouldn't be cleaner to have a separate LinearMMDDrift and BMMDDrift (for lack of a better name) instead of grouping everything into the existing MMD implementation. They would be a bit easier to debug as well, and could just share the MMD base class.

@Srceh Srceh marked this pull request as ready for review April 1, 2022 15:35
@Srceh Srceh requested review from arnaudvl and ojcobb April 4, 2022 15:15
…eshold with the linear-time estimator, instead of permutation.
@Srceh (Author) commented Apr 6, 2022

Now the linear-time estimator also uses the Gaussian null distribution for the test threshold, so no permutations are required. It should be the fastest option, at the cost of lower test power and some unused samples.
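Roughly, the permutation-free test exploits that the statistic is a mean of i.i.d. per-pair terms, so it is asymptotically normal under the null. A sketch (not the PR's exact code, which is later tightened to a t-test, see below):

```python
import numpy as np
from scipy import stats

def linear_mmd_pvalue(h: np.ndarray) -> float:
    # h: the floor(n/2) per-pair terms of the linear-time estimator.
    # Under H0, mean(h) ~ N(0, var(h) / m) by the CLT, so compare the
    # standardised statistic against the standard normal tail.
    m = h.shape[0]
    z = h.mean() / np.sqrt(h.var(ddof=1) / m)
    return float(1 - stats.norm.cdf(z))
```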

def forward(self, x: Union[np.ndarray, torch.Tensor],
            y: Union[np.ndarray, torch.Tensor],
            infer_sigma: bool = False,
            diag: bool = False) -> torch.Tensor:
Contributor

Given that they refer to the same thing, perhaps we could keep consistency between this kwarg name and the naming convention adopted for the squared distance functions? So perhaps pairwise: bool = True?

Author

Indeed, now uses pairwise as suggested.


if infer_sigma or self.init_required:
    if self.trainable and infer_sigma:
        raise ValueError("Gradients cannot be computed w.r.t. an inferred sigma value")
    sigma = self.init_sigma_fn(x, y, dist)
    if not diag:
Contributor

Is this a good default behaviour to have? Could we end up with O(n^2) costs in places where the linear time estimator is being used specifically because such a cost would be infeasible?

Author

Now it directly uses the median of the non-pairwise distances.

@@ -69,15 +69,24 @@ def __init__(
def sigma(self) -> tf.Tensor:
    return tf.math.exp(self.log_sigma)

def call(self, x: tf.Tensor, y: tf.Tensor, infer_sigma: bool = False) -> tf.Tensor:
def call(self, x: tf.Tensor, y: tf.Tensor,
Contributor

See comments on the PyTorch version.

@@ -93,7 +115,43 @@ def batch_compute_kernel_matrix(
return k_mat


def mmd2_from_kernel_matrix(kernel_mat: torch.Tensor, m: int, permute: bool = False,
def linear_mmd2(x: torch.Tensor,
Contributor

Nitpick, but probably worth keeping indentation within function definitions consistent with all of the other functions.

Author

Fixed, should be consistent across the board now.

@@ -18,6 +18,7 @@ def __init__(
        x_ref: Union[np.ndarray, list],
        backend: str = 'tensorflow',
        p_val: float = .05,
        estimator: str = 'quad',
Contributor

Would estimator_complexity be more descriptive? (Or at least make it clear in the docstring.)

Author

Added extra description in the docstring.

k_yz = kernel(x=y[0::2, :], y=x[1::2, :], diag=True)

h = k_xx + k_yy - k_xy - k_yz
mmd2 = h.sum() / (n / 2.)
Contributor

Is there a reason we don't just use h.mean() and h.var()?

Author

Now uses h.mean(), and torch.var(..., unbiased=True) in the torch version. The TF version uses tf.reduce_mean and a manual correction.
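In other words, roughly (a sketch with a placeholder h; the TensorFlow correction shown in comments):

```python
import torch

h = torch.randn(50)            # placeholder for the per-pair terms
m = h.shape[0]                 # equals n / 2
mmd2 = h.mean()                # same value as h.sum() / (n / 2.)
var_h = h.var(unbiased=True)   # Bessel-corrected sample variance
# TensorFlow has no unbiased kwarg, hence the manual correction:
#   var_h = tf.reduce_sum((h - tf.reduce_mean(h)) ** 2) / (m - 1)
```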

def linear_mmd2(x: tf.Tensor,
                y: tf.Tensor,
                kernel: Callable,
                permute: bool = False) -> Tuple[tf.Tensor, tf.Tensor]:
Contributor

Is there a reason we offer the permute option for TensorFlow and not torch?

Author

Legacy issue; now removed from the TensorFlow version.

Comment on lines 144 to 147
k_xx = kernel(x_hat[0::2, :], x_hat[1::2, :], diag=True)
k_yy = kernel(y_hat[0::2, :], y_hat[1::2, :], diag=True)
k_xy = kernel(x_hat[0::2, :], y_hat[1::2, :], diag=True)
k_yz = kernel(y_hat[0::2, :], x_hat[1::2, :], diag=True)
Contributor

Seems like unnecessary duplication

Author

Removed.

mmd2 = mmd2.numpy().item()
var_mmd2 = var_mmd2.numpy().item()
std_mmd2 = np.sqrt(var_mmd2)
p_val = 1 - stats.norm.cdf(mmd2 * np.sqrt(n_hat), loc=0., scale=std_mmd2*np.sqrt(2))
Contributor

Nitpick but should this be a t-test?

Author

Nice spot, now fixed with a t-test in both versions.
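A sketch of the t-test variant, under the same assumptions as the normal-approximation sketch earlier in the thread:

```python
import numpy as np
from scipy import stats

def linear_mmd_pvalue_t(h: np.ndarray) -> float:
    # Same standardised statistic as the normal version, but with a
    # Student-t tail (df = m - 1) since var(h) is itself estimated
    # from the m per-pair terms.
    m = h.shape[0]
    t_stat = h.mean() / np.sqrt(h.var(ddof=1) / m)
    return float(1 - stats.t.cdf(t_stat, df=m - 1))
```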

mmd2 = mmd2.cpu()
mmd2 = mmd2.numpy().item()
var_mmd2 = var_mmd2.numpy().item()
std_mmd2 = np.sqrt(var_mmd2)
Contributor

Can directly use torch.std(...) in linear_mmd2? This would remove the few additional lines of code here.

Author

The new version clips the variance from below, np.sqrt(np.clip(var_mmd2, 1e-8, None)), for numeric stability.

@@ -30,6 +30,28 @@ def squared_pairwise_distance(x: torch.Tensor, y: torch.Tensor, a_min: float = 1
    return dist.clamp_min_(a_min)


def squared_distance(x: torch.Tensor, y: torch.Tensor, a_min: float = 1e-30) -> torch.Tensor:
Contributor

Can we just apply a reduction to the squared_pairwise_distance instead of using an extra function?

Author

Now implemented as a single function.
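Something along the following lines (a sketch; the PR keeps the existing name squared_pairwise_distance, and the shorter name and torch.cdist here are illustrative):

```python
import torch

def squared_distance(x: torch.Tensor, y: torch.Tensor,
                     a_min: float = 1e-30, pairwise: bool = True) -> torch.Tensor:
    # pairwise=True: full [Nx, Ny] distance matrix (quadratic cost).
    # pairwise=False: distances between paired rows only (linear cost).
    if pairwise:
        dist = torch.cdist(x, y) ** 2
    else:
        dist = ((x - y) ** 2).sum(-1)
    return dist.clamp_min(a_min)
```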

m = np.shape(y)[0]
if n != m:
    raise RuntimeError("Linear-time estimator requires equal size samples")
k_xx = kernel(x=x[0::2, :], y=x[1::2, :], pairwise=False)
@arnaudvl (Contributor) commented Apr 27, 2022

We should be able to do this at init time (so self.k_xx becomes useful again), saving compute at prediction time.

Author

Fixed, now the kernel matrix is reused for prediction.
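A hypothetical sketch of the caching idea, assuming an element-wise kernel callable and a test batch at least as large as the reference split:

```python
import torch

class LinearTimeMMDSketch:
    def __init__(self, x_ref: torch.Tensor, kernel):
        n = x_ref.shape[0] // 2 * 2
        self.x_ref = x_ref[:n]
        self.kernel = kernel
        # k_xx only depends on the reference data, so compute it once here
        # and reuse it on every call to score().
        self.k_xx = kernel(self.x_ref[0::2], self.x_ref[1::2])

    def score(self, y: torch.Tensor) -> torch.Tensor:
        x = self.x_ref
        y = y[:x.shape[0]]  # assumes len(y) >= len(x_ref); sketch only
        h = (self.k_xx
             + self.kernel(y[0::2], y[1::2])
             - self.kernel(x[0::2], y[1::2])
             - self.kernel(y[0::2], x[1::2]))
        return h.mean()
```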

"""
n = np.shape(x)[0]
m = np.shape(y)[0]
if n != m:
Contributor

This behaviour should in my opinion already be checked beforehand (see comment in the method itself).

Author

Fixed.

k_xx = kernel(x=x[0::2, :], y=x[1::2, :], pairwise=False)
k_yy = kernel(x=y[0::2, :], y=y[1::2, :], pairwise=False)
k_xy = kernel(x=x[0::2, :], y=y[1::2, :], pairwise=False)
k_yz = kernel(x=y[0::2, :], y=x[1::2, :], pairwise=False)
Contributor

Is k_yz the paper's notation? Because it might be easier to follow by just calling it k_yx.

Author

Typo, thanks for noticing, fixed.

@@ -68,16 +68,24 @@ def __init__(
def sigma(self) -> torch.Tensor:
    return self.log_sigma.exp()

def forward(self, x: Union[np.ndarray, torch.Tensor], y: Union[np.ndarray, torch.Tensor],
            infer_sigma: bool = False) -> torch.Tensor:
def forward(self, x: Union[np.ndarray, torch.Tensor],
Contributor

Nitpicking big time here, but let's try to keep the same type of indentation as, e.g., in the DeepKernel below.

Author

Fixed.


x, y = torch.as_tensor(x), torch.as_tensor(y)
dist = distance.squared_pairwise_distance(x.flatten(1), y.flatten(1)) # [Nx, Ny]
if pairwise:
Contributor

Check my comment in distance.py, which might make this if/else redundant and reduce it to a kwarg of the distance function.

Author

Fixed, now as part of the squared_pairwise_distance function argument.

if pairwise:
    sigma = self.init_sigma_fn(x, y, dist)
else:
    sigma = (.5 * dist.flatten().sort().values[dist.shape[0] // 2 - 1].unsqueeze(dim=-1)) ** .5
Contributor

Again I think we can avoid the hard-coding of this behaviour and fall back on self.init_sigma_fn but with the desired linear detector behaviour.

Author

Slightly tricky, as the default init_sigma_fn is used by other detectors. Might be easier to keep the additional line here?
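For clarity, the hard-coded line above amounts to the following init_sigma_fn-style callable (a sketch):

```python
import torch

def sigma_median_diag(x: torch.Tensor, y: torch.Tensor,
                      dist: torch.Tensor) -> torch.Tensor:
    # Median heuristic on the element-wise (non-pairwise) distances: pick
    # sigma so that the (approximate) median squared distance is 2 * sigma^2.
    n_med = dist.shape[0] // 2 - 1
    return (0.5 * dist.flatten().sort().values[n_med]).sqrt().unsqueeze(-1)
```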

@arnaudvl (Contributor)
Left a number of comments related to the PyTorch implementation. Let's work through those first and then we can apply the desired changes to TensorFlow as well.

@ascillitoe (Contributor) left a comment

"Requesting changes" to ensure we do not merge until #489 has been merged and predict updated in this PR.


@@ -20,14 +26,44 @@ def squared_pairwise_distance(x: tf.Tensor, y: tf.Tensor, a_min: float = 1e-30,
    Lower bound to clip distance values.
a_max
    Upper bound to clip distance values.
pairwise
    Whether to compute the full pairwise distance matrix, or only the element-wise distances between paired rows of x and y.
Contributor

Isn't it a bit unclear to have a function named squared_pairwise_distance that optionally computes non-pairwise distances? Perhaps squared_pairwise_distance should be renamed squared_distance? Or the pairwise=False functionality separated out into a separate distance function?

Author

I was thinking the same. It was previously a separate function, and @arnaudvl suggested minimising the repeated parts. I guess changing the function name across all the related methods would be preferable.

@Srceh (Author) commented May 18, 2022

> Left a number of comments related to the PyTorch implementation. Let's work through those first and then we can apply the desired changes to TensorFlow as well.

The TF version has now also been fixed for all of the comments above that were replied to with "fixed".

@ascillitoe (Contributor)

@Srceh I've just merged the score/predict refactoring (#489). This will have introduced a few conflicts you need to resolve. It should simplify your life with respect to the implementation in this PR though!

@Srceh (Author) commented May 18, 2022

> @Srceh I've just merged the score/predict refactoring (#489). This will have introduced a few conflicts you need to resolve. It should simplify your life with respect to the implementation in this PR though!

Nice! Will start working on that!

…IO#489).

Merge branch 'master' into linear_time_mmd

# Conflicts:
#	alibi_detect/cd/base.py
#	alibi_detect/cd/mmd.py
#	alibi_detect/cd/pytorch/mmd.py
#	alibi_detect/cd/tensorflow/mmd.py
@Srceh Srceh requested a review from ascillitoe May 21, 2022 19:08
@ascillitoe ascillitoe modified the milestones: v0.10.0, v0.10.1 Jul 12, 2022
@ascillitoe (Contributor)

@arnaudvl @Srceh I will resolve the conflicts for you once #537 has been merged.

@ascillitoe (Contributor)

@Srceh I have now merged in the v0.10.0 related changes from master. This primarily involved changes to the kwargs related to preprocessing, tweaking some tests, and adding your estimator kwarg to the MMDDrift pydantic models (see saving/schemas.py).

if self.device.type == 'cuda':
    mmd2, mmd2_permuted = mmd2.cpu(), mmd2_permuted.cpu()
p_val = (mmd2 <= mmd2_permuted).float().mean()
# compute distance threshold
idx_threshold = int(self.p_val * len(mmd2_permuted))
distance_threshold = torch.sort(mmd2_permuted, descending=True).values[idx_threshold]
return p_val.numpy().item(), mmd2.numpy().item(), distance_threshold.numpy()


class LinearTimeMMDDriftTorch(BaseMMDDrift):
Contributor

Since these new subclasses don't make use of self.n_permutations (set in BaseMMDDrift), shall we set this to None? I had a moment of confusion when updating the tests, since self.n_permutations == 100 when estimator == 'linear'.

Author

Good point. The default number of permutations can then be initialised in cd/mmd.py when estimator is 'quad'.

    self._detector = MMDDriftTF(*args, **kwargs)  # type: ignore
elif estimator == 'linear':
    kwargs.pop('n_permutations', None)
    self._detector = LinearTimeMMDDriftTF(*args, **kwargs)  # type: ignore
Contributor

Since the logic to set self._detector is located here, we should add additional tests to alibi_detect/cd/tests/test_mmd.py to check that the correct subclass is selected conditional on backend and estimator.

Author

Indeed, will modify the tests.

Author

Simply rewriting the test to go through the different backend and estimator options should do the job.
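Something like the following (a hypothetical test sketch; it assumes the estimator kwarg and the subclass names from this PR):

```python
import numpy as np
import pytest
from alibi_detect.cd import MMDDrift

@pytest.mark.parametrize('backend', ['pytorch', 'tensorflow'])
@pytest.mark.parametrize('estimator', ['quad', 'linear'])
def test_detector_selection(backend, estimator):
    x_ref = np.random.randn(100, 5).astype(np.float32)
    cd = MMDDrift(x_ref, backend=backend, estimator=estimator)
    # e.g. LinearTimeMMDDriftTorch / LinearTimeMMDDriftTF vs the quadratic
    # MMDDriftTorch / MMDDriftTF.
    name = cd._detector.__class__.__name__
    assert name.startswith('LinearTime') == (estimator == 'linear')
```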

@CLAassistant commented May 7, 2024

CLA assistant check
Thank you for your submission! We really appreciate it. Like many open source projects, we ask that you all sign our Contributor License Agreement before we can accept your contribution.
0 out of 2 committers have signed the CLA:

❌ ascillitoe
❌ Srceh

Successfully merging this pull request may close these issues: Integrate linear time MMD detector.